Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move to ChainRulesCore v1.0 (in OptingOut) #1035

Merged
merged 21 commits into from
Jul 27, 2021
Merged

Move to ChainRulesCore v1.0 (in OptingOut) #1035

merged 21 commits into from
Jul 27, 2021

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Jul 20, 2021

This PR will make the changes needed to support ChainRulesCore v1.0
A number of commits will be added to it to fix each breaking change.

Right now I am aware of 2 such changes

Copy link
Collaborator

@mzgubic mzgubic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally the mechanism seems fine. I guess that args_T is always just concrete types. I do still have some questions on how the functions work - but seems fine overall

# Now try opting out After we have already used it
@opt_out ChainRulesCore.rrule(::typeof(oa_id), x::Real)
oa_id_rrule_hitcount[] = 0
oa_id_outer(x) = sum(oa_id(x))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we redefine the function here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably a mistake

src/compiler/chainrules.jl Show resolved Hide resolved
src/compiler/chainrules.jl Show resolved Hide resolved
src/compiler/chainrules.jl Show resolved Hide resolved
end

do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m)
if do_not_use_rrule
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this meant to do decomposition?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is has_chain_rrule and if it returns false,... then pullback will indeed end up calling generate_pullback_via_decomposition.

@mzgubic
Copy link
Collaborator

mzgubic commented Jul 26, 2021

minimal failing example for the remaining test failure is:

    using Zygote 

    W = ones(Float32, 3)
    ps = Zygote.Params([W, ])

    gs = gradient(ps) do
        p, pb = pullback(ps) do
            sum(W)
        end
        g = pb(p)
        sum(g[W])
    end
(stacktrace)

ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./iddict.jl:87 [inlined]
  [3] (::typeof((get)))(Δ::Nothing)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/JuliaEnvs/Zygote.jl/src/lib/lib.jl:68 [inlined]
  [5] (::typeof((accum_global)))(Δ::Nothing)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/JuliaEnvs/Zygote.jl/src/lib/lib.jl:79 [inlined]
  [7] (::typeof((λ)))(Δ::Nothing)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
  [9] (::typeof((λ)))(Δ::Nothing)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
 [10] getindex
    @ ./tuple.jl:29 [inlined]
 [11] (::typeof((λ)))(Δ::Nothing)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:348 [inlined]
 [13] (::typeof((λ)))(Δ::Nothing)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./REPL[5]:5 [inlined]
 [15] (::typeof((#5)))(Δ::Float32)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#90#91"{Params, typeof((#5)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:348
 [17] gradient(f::Function, args::Params)
    @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:76
 [18] top-level scope
    @ REPL[5]:1

where the failing expression is
ex = :($(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)))

Removing Params results in no error being thrown, as does only doing first order AD. So combining Params and second order AD seems to have gone bad because of ProjectTo or possibly opting out of rules.

@oxinabox
Copy link
Member Author

minimal failing example for the remaining test failure is:

That fails to me with the current released version of Zygote for me also.

      Status `/tmp/jl_QkhRV2/Project.toml`
  [082447d4] ChainRules v0.8.23
  [d360d2e6] ChainRulesCore v0.10.13
  [e88e6eb3] Zygote v0.6.17

@DhairyaLGandhi
Copy link
Member

Check with #823?

@oxinabox
Copy link
Member Author

oxinabox commented Jul 26, 2021

Check with #823?

Yes it works on that branch.
I will see if it works in this branch with that one merged into it
Edit: yes if i merge that into this one it works here also,
but a bunch of other things break, just like they do in #823

@oxinabox
Copy link
Member Author

oxinabox commented Jul 26, 2021

Better MWE: that works on current release, but fails on this PR

using ChainRulesCore
using Test
using Zygote

@testset "Params nesting" begin
    struct Dense{F,T,S}
      W::T
      b::S
      σ::F
    end
  
    (d::Dense)(x) = d.σ.(d.W * x .+ d.b)
    d = Dense(ones(Float32, 3,3), zeros(Float32, 3), identity)
    ps = Zygote.Params([d.W, d.b])
    r = ones(Float32, 3,3)
  
    gs = gradient(ps) do
      p, pb = pullback(ps) do
        sum(d(r))
      end
      g = pb(p)
      sum(g[d.W]) # + sum(g[d.b])
    end
end

That particular issue is fixed by JuliaDiff/ChainRulesCore.jl#414

but other issues still remain.

@DhairyaLGandhi
Copy link
Member

Probably worth it to merge #823 with a cleanup then.

@oxinabox
Copy link
Member Author

oxinabox commented Jul 26, 2021

Probably worth it to merge #823 with a cleanup then.

No, i fixed it to not need that.
I think #823 fixed it kinda by coincidence.
Or maybe it didn't even fix the real thing at all, it was a bad MWE.


This is now passing all tests locally.
It is just blocked by JuliaMath/SpecialFunctions.jl#335
once that is in this should be ready to go.

And by JuliaDiff/ForwardDiff.jl#538
without resolving that that we can't do Hessians it seems

@@ -275,7 +304,7 @@ end
ZygoteRuleConfig(), my_namedtuple, 1., 2., 3.; rrule_f=rrule_via_ad
)
test_rrule(
ZygoteRuleConfig(), my_namedtuple, 1., (2.0, "str"), 3.; rrule_f=rrule_via_ad
ZygoteRuleConfig(), my_namedtuple, 1., (2.0, 2.4), 3.; rrule_f=rrule_via_ad
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to change this due to an issue in ChainRulesTestUtils
JuliaDiff/ChainRulesTestUtils.jl#194

The Zygote and ChainRules code handles the string fine,
but ChainRulesTestUtils kinda freaks out about it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can keep it as a gradtest then? I'm not super sure if CRTU is expected to be a dependency for every AD test suite.

Copy link
Member Author

@oxinabox oxinabox Jul 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am pretty sure gradtest can't handle it either -- gradtest can't handle namedtuples containing tuples.

If you think it is important to have, to test what rrule_via_ad does when confronted with a string,
I can add a test specifically for that.

I'm not super sure if CRTU is expected to be a dependency for every AD test suite.

It is, if you have rrule_via_ad overloaded. (or rrule)
Then you can use it to test all things.
It's a more robust version of gradtest that can handle more types, and doesn't e.g. get tricked by antisymmetric errors.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its good to have the coverage over different data types, but I guess its going to need a future PR. I can already sense some SMILES like application making use of strings outside of embeddings.

@@ -18,7 +18,7 @@ using Zygote, Test, LinearAlgebra
@test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10

# https://github.com/FluxML/Zygote.jl/issues/705
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ im .* exp.(1:3)
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because ChainRules takes embedded subspaces seriously.
Derivative of a real array can not be imaginary.

The thing that #705 was worried about is still fixed

@@ -449,12 +449,12 @@ end
@test pullback(type_test)[1] == Complex{<:Real}

@testset "Pairs" begin
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
Copy link
Member Author

@oxinabox oxinabox Jul 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is because we now take embedded sub-spaces seriously.
Integers are always considered to be a subspace of Floats.
Not just went it happens by coincidence

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be compiling to integer code anyway?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Integers are not a good type to use to represent tangents.
Because if you are going to do gradeient decent you are going to apply a learning rate like 0.1*dx.
So we call float on them to get the corresponding floating point type.
(Unless it is a index or something, then it would be NoTangent())

Arguably we could keep integers here, but encouraging people to use integers to repressent continous values that just so happen do be integers feels like not taking subspace types seriously
🤷

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's a matter of letting the language take charge of these things. It's possible to work with just ints.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be moved to a issue on ChainRulesCore.jl?
Its not going to be resolved in this PR.

@@ -81,7 +81,7 @@ end
@test gradient(xs ->sum(xs .^ _pow), [4, -1]) == ([_pow*4^9, -10],)

@test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,)
@test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ (-234 + 2im)*log(5 - 7im)
@test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ real((-234 + 2im)*log(5 - 7im))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again: primal was real so too much be it's derivative

@oxinabox oxinabox changed the title WIP: Move to ChainRulesCore v1.0 (in OptingOut) Move to ChainRulesCore v1.0 (in OptingOut) Jul 26, 2021
Project.toml Outdated Show resolved Hide resolved
@@ -275,7 +304,7 @@ end
ZygoteRuleConfig(), my_namedtuple, 1., 2., 3.; rrule_f=rrule_via_ad
)
test_rrule(
ZygoteRuleConfig(), my_namedtuple, 1., (2.0, "str"), 3.; rrule_f=rrule_via_ad
ZygoteRuleConfig(), my_namedtuple, 1., (2.0, 2.4), 3.; rrule_f=rrule_via_ad
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can keep it as a gradtest then? I'm not super sure if CRTU is expected to be a dependency for every AD test suite.

@@ -449,12 +449,12 @@ end
@test pullback(type_test)[1] == Complex{<:Real}

@testset "Pairs" begin
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be compiling to integer code anyway?

test/utils.jl Outdated Show resolved Hide resolved
Co-authored-by: Dhairya Gandhi <[email protected]>
src/compiler/interface.jl Show resolved Hide resolved
src/compiler/chainrules.jl Outdated Show resolved Hide resolved
# rrule: specific, no_rrule: fallback => !matches => do use rrule, as haven't opted out.
# rrule: fallback, no_rrule: specific => IMPOSSIBLE, every no_rule us identical to some rrule
# rrule: specific, no_rrule: specific => matches => do not use rrule as opted out
# rrule: specific, no_rrule: general => !matches => do use rrule as a more specific rrule takes preciedent over more general opted out
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So its the kind of complement of @nograd f(x::SomeType, y::SomeOtherType) where we can still get grads for some methods, and not for others? Kind of defining more specific rrules and dispatching to them, but doing so manually here, rather than letting Julia do it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is like defining a more specific rrule where that more specific rrule is "Let AD work it out".
But we can't use rrule_via_ad for this because you hit a stackoverflow.

It is not really like @nograd (@non_differentiable is like @nograd) except that both participate in rrule dispatch via specificity

src/compiler/chainrules.jl Outdated Show resolved Hide resolved
@@ -275,7 +304,7 @@ end
ZygoteRuleConfig(), my_namedtuple, 1., 2., 3.; rrule_f=rrule_via_ad
)
test_rrule(
ZygoteRuleConfig(), my_namedtuple, 1., (2.0, "str"), 3.; rrule_f=rrule_via_ad
ZygoteRuleConfig(), my_namedtuple, 1., (2.0, 2.4), 3.; rrule_f=rrule_via_ad
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its good to have the coverage over different data types, but I guess its going to need a future PR. I can already sense some SMILES like application making use of strings outside of embeddings.

@oxinabox oxinabox merged commit 29e3b60 into master Jul 27, 2021
@oxinabox oxinabox deleted the ox/cr1 branch July 27, 2021 18:11
@mcabbott mcabbott mentioned this pull request Jul 27, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants